# encoding: UTF-8

'''
本文件中包含的是CTA模块的回测引擎，回测引擎的API和CTA引擎一致，
可以使用和实盘相同的代码进行回测。
'''
from __future__ import division

from datetime import datetime, timedelta
from collections import OrderedDict
from itertools import product
import pymongo

from ctaBase import *
from ctaSetting import *

from vtConstant import *
from vtGateway import VtOrderData, VtTradeData
from vtFunction import loadMongoSetting


########################################################################
class BacktestingEngine(object):
	"""
	CTA回测引擎
	函数接口和策略引擎保持一样，
	从而实现同一套代码从回测到实盘。
	"""
	
	TICK_MODE = 'tick'
	BAR_MODE = 'bar'

	#----------------------------------------------------------------------
	def __init__(self):
		"""Constructor"""
		# 本地停止单编号计数
		self.stopOrderCount = 0
		# stopOrderID = STOPORDERPREFIX + str(stopOrderCount)
		
		# 本地停止单字典
		# key为stopOrderID，value为stopOrder对象
		self.stopOrderDict = {}			 # 停止单撤销后不会从本字典中删除
		self.workingStopOrderDict = {}	  # 停止单撤销后会从本字典中删除
		
		# 引擎类型为回测
		self.engineType = ENGINETYPE_BACKTESTING
		
		# 回测相关
		self.strategy = None		# 回测策略
		self.mode = self.BAR_MODE   # 回测模式，默认为K线
		
		self.slippage = 0		   # 回测时假设的滑点
		self.rate = 0			   # 回测时假设的佣金比例（适用于百分比佣金）
		self.size = 1			   # 合约大小，默认为1
		
		self.dbClient = None		# 数据库客户端
		self.dbCursor = None		# 数据库指针
		
		#self.historyData = []	   # 历史数据的列表，回测用
		self.initData = []		  # 初始化用的数据
		#self.backtestingData = []   # 回测用的数据
		
		self.dbName = ''			# 回测数据库名
		self.symbol = ''			# 回测集合名
		
		self.dataStartDate = None	   # 回测数据开始日期，datetime对象
		self.dataEndDate = None		 # 回测数据结束日期，datetime对象
		self.strategyStartDate = None   # 策略启动日期（即前面的数据用于初始化），datetime对象
		
		self.limitOrderDict = OrderedDict()		 # 限价单字典
		self.workingLimitOrderDict = OrderedDict()  # 活动限价单字典，用于进行撮合用
		self.limitOrderCount = 0					# 限价单编号
		
		self.tradeCount = 0			 # 成交编号
		self.tradeDict = OrderedDict()  # 成交字典
		
		self.logList = []			   # 日志记录
		
		# 当前最新数据，用于模拟成交用
		self.tick = None
		self.bar = None
		self.dt = None	  # 最新的时间
		
	#----------------------------------------------------------------------
	def setStartDate(self, startDate='20100416', initDays=10):
		"""设置回测的启动日期"""
		self.dataStartDate = datetime.strptime(startDate, '%Y%m%d')
		
		initTimeDelta = timedelta(initDays)
		self.strategyStartDate = self.dataStartDate + initTimeDelta
		
	#----------------------------------------------------------------------
	def setEndDate(self, endDate=''):
		"""设置回测的结束日期"""
		if endDate:
			self.dataEndDate= datetime.strptime(endDate, '%Y%m%d')
		
	#----------------------------------------------------------------------
	def setBacktestingMode(self, mode):
		"""设置回测模式"""
		self.mode = mode
	
	#----------------------------------------------------------------------
	def setDatabase(self, dbName, symbol):
		"""设置历史数据所用的数据库"""
		self.dbName = dbName
		self.symbol = symbol
	
	#----------------------------------------------------------------------
	def loadHistoryData(self):
		"""载入历史数据"""
		host, port = loadMongoSetting()
		
		self.dbClient = pymongo.MongoClient(host, port)
		collection = self.dbClient[self.dbName][self.symbol]

		self.output(u'开始载入数据')
	  
		# 首先根据回测模式，确认要使用的数据类
		if self.mode == self.BAR_MODE:
			dataClass = CtaBarData
			func = self.newBar
		else:
			dataClass = CtaTickData
			func = self.newTick

		# 载入初始化需要用的数据
		flt = {'datetime':{'$gte':self.dataStartDate,
						   '$lt':self.strategyStartDate}}
		initCursor = collection.find(flt)
		
		# 将数据从查询指针中读取出，并生成列表
		for d in initCursor:
			data = dataClass()
			data.__dict__ = d
			self.initData.append(data)
		
		# 载入回测数据
		if not self.dataEndDate:
			flt = {'datetime':{'$gte':self.strategyStartDate}}   # 数据过滤条件
		else:
			flt = {'datetime':{'$gte':self.strategyStartDate,
							   '$lte':self.dataEndDate}}
		self.dbCursor = collection.find(flt)
		
		self.output(u'载入完成，数据量：%s' %(initCursor.count() + self.dbCursor.count()))
		
	#----------------------------------------------------------------------
	def runBacktesting(self):
		"""运行回测"""
		# 载入历史数据
		self.loadHistoryData()
		
		# 首先根据回测模式，确认要使用的数据类
		if self.mode == self.BAR_MODE:
			dataClass = CtaBarData
			func = self.newBar
		else:
			dataClass = CtaTickData
			func = self.newTick

		self.output(u'开始回测')
		
		self.strategy.inited = True
		self.strategy.onInit()
		self.output(u'策略初始化完成')
		
		self.strategy.trading = True
		self.strategy.onStart()
		self.output(u'策略启动完成')
		
		self.output(u'开始回放数据')

		for d in self.dbCursor:
			data = dataClass()
			data.__dict__ = d
			func(data)
			
		self.output(u'数据回放结束')
		
	#----------------------------------------------------------------------
	def newBar(self, bar):
		"""新的K线"""
		self.bar = bar
		self.dt = bar.datetime
		self.crossLimitOrder()	  # 先撮合限价单
		self.crossStopOrder()	   # 再撮合停止单
		self.strategy.onBar(bar)	# 推送K线到策略中
	
	#----------------------------------------------------------------------
	def newTick(self, tick):
		"""新的Tick"""
		self.tick = tick
		self.dt = tick.datetime
		self.crossLimitOrder()
		self.crossStopOrder()
		self.strategy.onTick(tick)
		
	#----------------------------------------------------------------------
	def initStrategy(self, strategyClass, setting=None):
		"""
		初始化策略
		setting是策略的参数设置，如果使用类中写好的默认设置则可以不传该参数
		"""
		self.strategy = strategyClass(self, setting)
		self.strategy.name = self.strategy.className
		
	#----------------------------------------------------------------------
	def sendOrder(self, vtSymbol, orderType, price, volume, strategy):
		"""发单"""
		self.limitOrderCount += 1
		orderID = str(self.limitOrderCount)
		
		order = VtOrderData()
		order.vtSymbol = vtSymbol
		order.price = price
		order.totalVolume = volume
		order.status = STATUS_NOTTRADED	 # 刚提交尚未成交
		order.orderID = orderID
		order.vtOrderID = orderID
		order.orderTime = str(self.dt)
		
		# CTA委托类型映射
		if orderType == CTAORDER_BUY:
			order.direction = DIRECTION_LONG
			order.offset = OFFSET_OPEN
		elif orderType == CTAORDER_SELL:
			order.direction = DIRECTION_SHORT
			order.offset = OFFSET_CLOSE
		elif orderType == CTAORDER_SHORT:
			order.direction = DIRECTION_SHORT
			order.offset = OFFSET_OPEN
		elif orderType == CTAORDER_COVER:
			order.direction = DIRECTION_LONG
			order.offset = OFFSET_CLOSE
		
		# 保存到限价单字典中
		self.workingLimitOrderDict[orderID] = order
		self.limitOrderDict[orderID] = order
		
		return orderID
	
	#----------------------------------------------------------------------
	def cancelOrder(self, vtOrderID):
		"""撤单"""
		if vtOrderID in self.workingLimitOrderDict:
			order = self.workingLimitOrderDict[vtOrderID]
			order.status = STATUS_CANCELLED
			order.cancelTime = str(self.dt)
			del self.workingLimitOrderDict[vtOrderID]
		
	#----------------------------------------------------------------------
	def sendStopOrder(self, vtSymbol, orderType, price, volume, strategy):
		"""发停止单（本地实现）"""
		self.stopOrderCount += 1
		stopOrderID = STOPORDERPREFIX + str(self.stopOrderCount)
		
		so = StopOrder()
		so.vtSymbol = vtSymbol
		so.price = price
		so.volume = volume
		so.strategy = strategy
		so.stopOrderID = stopOrderID
		so.status = STOPORDER_WAITING
		
		if orderType == CTAORDER_BUY:
			so.direction = DIRECTION_LONG
			so.offset = OFFSET_OPEN
		elif orderType == CTAORDER_SELL:
			so.direction = DIRECTION_SHORT
			so.offset = OFFSET_CLOSE
		elif orderType == CTAORDER_SHORT:
			so.direction = DIRECTION_SHORT
			so.offset = OFFSET_OPEN
		elif orderType == CTAORDER_COVER:
			so.direction = DIRECTION_LONG
			so.offset = OFFSET_CLOSE
		
		# 保存stopOrder对象到字典中
		self.stopOrderDict[stopOrderID] = so
		self.workingStopOrderDict[stopOrderID] = so
		
		return stopOrderID
	
	#----------------------------------------------------------------------
	def cancelStopOrder(self, stopOrderID):
		"""撤销停止单"""
		# 检查停止单是否存在
		if stopOrderID in self.workingStopOrderDict:
			so = self.workingStopOrderDict[stopOrderID]
			so.status = STOPORDER_CANCELLED
			del self.workingStopOrderDict[stopOrderID]
			
	#----------------------------------------------------------------------
	def crossLimitOrder(self):
		"""基于最新数据撮合限价单"""
		# 先确定会撮合成交的价格
		if self.mode == self.BAR_MODE:
			buyCrossPrice = self.bar.low		# 若买入方向限价单价格高于该价格，则会成交
			sellCrossPrice = self.bar.high	  # 若卖出方向限价单价格低于该价格，则会成交
			buyBestCrossPrice = self.bar.open   # 在当前时间点前发出的买入委托可能的最优成交价
			sellBestCrossPrice = self.bar.open  # 在当前时间点前发出的卖出委托可能的最优成交价
		else:
			buyCrossPrice = self.tick.askPrice1
			sellCrossPrice = self.tick.bidPrice1
			buyBestCrossPrice = self.tick.askPrice1
			sellBestCrossPrice = self.tick.bidPrice1
		
		# 遍历限价单字典中的所有限价单
		for orderID, order in self.workingLimitOrderDict.items():
			# 判断是否会成交
			buyCross = order.direction==DIRECTION_LONG and order.price>=buyCrossPrice
			sellCross = order.direction==DIRECTION_SHORT and order.price<=sellCrossPrice
			
			# 如果发生了成交
			if buyCross or sellCross:
				# 推送成交数据
				self.tradeCount += 1			# 成交编号自增1
				tradeID = str(self.tradeCount)
				trade = VtTradeData()
				trade.vtSymbol = order.vtSymbol
				trade.tradeID = tradeID
				trade.vtTradeID = tradeID
				trade.orderID = order.orderID
				trade.vtOrderID = order.orderID
				trade.direction = order.direction
				trade.offset = order.offset
				
				# 以买入为例：
				# 1. 假设当根K线的OHLC分别为：100, 125, 90, 110
				# 2. 假设在上一根K线结束(也是当前K线开始)的时刻，策略发出的委托为限价105
				# 3. 则在实际中的成交价会是100而不是105，因为委托发出时市场的最优价格是100
				if buyCross:
					trade.price = min(order.price, buyBestCrossPrice)
					self.strategy.pos += order.totalVolume
				else:
					trade.price = max(order.price, sellBestCrossPrice)
					self.strategy.pos -= order.totalVolume
				
				trade.volume = order.totalVolume
				trade.tradeTime = str(self.dt)
				trade.dt = self.dt
				self.strategy.onTrade(trade)
				
				self.tradeDict[tradeID] = trade
				
				# 推送委托数据
				order.tradedVolume = order.totalVolume
				order.status = STATUS_ALLTRADED
				self.strategy.onOrder(order)
				
				# 从字典中删除该限价单
				del self.workingLimitOrderDict[orderID]
				
	#----------------------------------------------------------------------
	def crossStopOrder(self):
		"""基于最新数据撮合停止单"""
		# 先确定会撮合成交的价格，这里和限价单规则相反
		if self.mode == self.BAR_MODE:
			buyCrossPrice = self.bar.high	# 若买入方向停止单价格低于该价格，则会成交
			sellCrossPrice = self.bar.low	# 若卖出方向限价单价格高于该价格，则会成交
			bestCrossPrice = self.bar.open   # 最优成交价，买入停止单不能低于，卖出停止单不能高于
		else:
			buyCrossPrice = self.tick.lastPrice
			sellCrossPrice = self.tick.lastPrice
			bestCrossPrice = self.tick.lastPrice
		
		# 遍历停止单字典中的所有停止单
		for stopOrderID, so in self.workingStopOrderDict.items():
			# 判断是否会成交
			buyCross = so.direction==DIRECTION_LONG and so.price<=buyCrossPrice
			sellCross = so.direction==DIRECTION_SHORT and so.price>=sellCrossPrice
			
			# 如果发生了成交
			if buyCross or sellCross:
				# 推送成交数据
				self.tradeCount += 1			# 成交编号自增1
				tradeID = str(self.tradeCount)
				trade = VtTradeData()
				trade.vtSymbol = so.vtSymbol
				trade.tradeID = tradeID
				trade.vtTradeID = tradeID
				
				if buyCross:
					self.strategy.pos += so.volume
					trade.price = max(bestCrossPrice, so.price)
				else:
					self.strategy.pos -= so.volume
					trade.price = min(bestCrossPrice, so.price)
				
				self.limitOrderCount += 1
				orderID = str(self.limitOrderCount)
				trade.orderID = orderID
				trade.vtOrderID = orderID
				
				trade.direction = so.direction
				trade.offset = so.offset
				trade.volume = so.volume
				trade.tradeTime = str(self.dt)
				trade.dt = self.dt
				self.strategy.onTrade(trade)
				
				self.tradeDict[tradeID] = trade
				
				# 推送委托数据
				so.status = STOPORDER_TRIGGERED
				
				order = VtOrderData()
				order.vtSymbol = so.vtSymbol
				order.symbol = so.vtSymbol
				order.orderID = orderID
				order.vtOrderID = orderID
				order.direction = so.direction
				order.offset = so.offset
				order.price = so.price
				order.totalVolume = so.volume
				order.tradedVolume = so.volume
				order.status = STATUS_ALLTRADED
				order.orderTime = trade.tradeTime
				self.strategy.onOrder(order)
				
				self.limitOrderDict[orderID] = order
				
				# 从字典中删除该限价单
				del self.workingStopOrderDict[stopOrderID]

	#----------------------------------------------------------------------
	def insertData(self, dbName, collectionName, data):
		"""考虑到回测中不允许向数据库插入数据，防止实盘交易中的一些代码出错"""
		pass
	
	#----------------------------------------------------------------------
	def loadBar(self, dbName, collectionName, startDate):
		"""直接返回初始化数据列表中的Bar"""
		return self.initData
	
	#----------------------------------------------------------------------
	def loadTick(self, dbName, collectionName, startDate):
		"""直接返回初始化数据列表中的Tick"""
		return self.initData
	
	#----------------------------------------------------------------------
	def writeCtaLog(self, content):
		"""记录日志"""
		log = str(self.dt) + ' ' + content
		self.logList.append(log)
		
	#----------------------------------------------------------------------
	def output(self, content):
		"""输出内容"""
		print str(datetime.now()) + "\t" + content
	
	#----------------------------------------------------------------------
	def calculateBacktestingResult(self):
		"""
		计算回测结果
		"""
		self.output(u'计算回测结果')
		
		# 首先基于回测后的成交记录，计算每笔交易的盈亏
		resultList = []			 # 交易结果列表
		
		longTrade = []			  # 未平仓的多头交易
		shortTrade = []			 # 未平仓的空头交易
		
		for trade in self.tradeDict.values():
			# 多头交易
			if trade.direction == DIRECTION_LONG:
				# 如果尚无空头交易
				if not shortTrade:
					longTrade.append(trade)
				# 当前多头交易为平空
				else:
					while True:
						entryTrade = shortTrade[0]
						exitTrade = trade
						
						# 清算开平仓交易
						closedVolume = min(exitTrade.volume, entryTrade.volume)
						result = TradingResult(entryTrade.price, entryTrade.dt,
											   exitTrade.price, exitTrade.dt,
											   -closedVolume, self.rate, self.slippage, self.size)
						resultList.append(result)
						
						# 计算未清算部分
						entryTrade.volume -= closedVolume
						exitTrade.volume -= closedVolume
						
						# 如果开仓交易已经全部清算，则从列表中移除
						if not entryTrade.volume:
							shortTrade.pop(0)
						
						# 如果平仓交易已经全部清算，则退出循环
						if not exitTrade.volume:
							break
						
						# 如果平仓交易未全部清算，
						if exitTrade.volume:
							# 且开仓交易已经全部清算完，则平仓交易剩余的部分
							# 等于新的反向开仓交易，添加到队列中
							if not shortTrade:
								longTrade.append(exitTrade)
								break
							# 如果开仓交易还有剩余，则进入下一轮循环
							else:
								pass
						
			# 空头交易
			else:
				# 如果尚无多头交易
				if not longTrade:
					shortTrade.append(trade)
				# 当前空头交易为平多
				else:
					while True:
						entryTrade = longTrade[0]
						exitTrade = trade
						
						# 清算开平仓交易
						closedVolume = min(exitTrade.volume, entryTrade.volume)
						result = TradingResult(entryTrade.price, entryTrade.dt,
											   exitTrade.price, exitTrade.dt,
											   closedVolume, self.rate, self.slippage, self.size)
						resultList.append(result)
						
						# 计算未清算部分
						entryTrade.volume -= closedVolume
						exitTrade.volume -= closedVolume
						
						# 如果开仓交易已经全部清算，则从列表中移除
						if not entryTrade.volume:
							longTrade.pop(0)
						
						# 如果平仓交易已经全部清算，则退出循环
						if not exitTrade.volume:
							break
						
						# 如果平仓交易未全部清算，
						if exitTrade.volume:
							# 且开仓交易已经全部清算完，则平仓交易剩余的部分
							# 等于新的反向开仓交易，添加到队列中
							if not longTrade:
								shortTrade.append(exitTrade)
								break
							# 如果开仓交易还有剩余，则进入下一轮循环
							else:
								pass
					
		# 检查是否有交易
		if not resultList:
			self.output(u'无交易结果')
			return {}
		
		# 然后基于每笔交易的结果，我们可以计算具体的盈亏曲线和最大回撤等
		capital = 0			 # 资金
		maxCapital = 0		  # 资金最高净值
		drawdown = 0			# 回撤
		
		totalResult = 0		 # 总成交数量
		totalTurnover = 0	   # 总成交金额（合约面值）
		totalCommission = 0	 # 总手续费
		totalSlippage = 0	   # 总滑点
		
		timeList = []		   # 时间序列
		pnlList = []			# 每笔盈亏序列
		capitalList = []		# 盈亏汇总的时间序列
		drawdownList = []	   # 回撤的时间序列
		
		winningResult = 0	   # 盈利次数
		losingResult = 0		# 亏损次数
		totalWinning = 0		# 总盈利金额
		totalLosing = 0		 # 总亏损金额
		
		for result in resultList:
			capital += result.pnl
			maxCapital = max(capital, maxCapital)
			drawdown = capital - maxCapital
			
			pnlList.append(result.pnl)
			timeList.append(result.exitDt)	  # 交易的时间戳使用平仓时间
			capitalList.append(capital)
			drawdownList.append(drawdown)
			
			totalResult += 1
			totalTurnover += result.turnover
			totalCommission += result.commission
			totalSlippage += result.slippage
			
			if result.pnl >= 0:
				winningResult += 1
				totalWinning += result.pnl
			else:
				losingResult += 1
				totalLosing += result.pnl
				
		# 计算盈亏相关数据
		winningRate = winningResult/totalResult*100	 # 胜率
		averageWinning = totalWinning/winningResult	 # 平均每笔盈利
		averageLosing = totalLosing/losingResult		# 平均每笔亏损
		profitLossRatio = -averageWinning/averageLosing # 盈亏比

		# 返回回测结果
		d = {}
		d['capital'] = capital
		d['maxCapital'] = maxCapital
		d['drawdown'] = drawdown
		d['totalResult'] = totalResult
		d['totalTurnover'] = totalTurnover
		d['totalCommission'] = totalCommission
		d['totalSlippage'] = totalSlippage
		d['timeList'] = timeList
		d['pnlList'] = pnlList
		d['capitalList'] = capitalList
		d['drawdownList'] = drawdownList
		d['winningRate'] = winningRate
		d['averageWinning'] = averageWinning
		d['averageLosing'] = averageLosing
		d['profitLossRatio'] = profitLossRatio
		
		return d
		
	#----------------------------------------------------------------------
	def showBacktestingResult(self):
		"""显示回测结果"""
		d = self.calculateBacktestingResult()
		
		# 输出
		self.output('-' * 30)
		self.output(u'第一笔交易：\t%s' % d['timeList'][0])
		self.output(u'最后一笔交易：\t%s' % d['timeList'][-1])
		
		self.output(u'总交易次数：\t%s' % formatNumber(d['totalResult']))
		self.output(u'总盈亏：\t%s' % formatNumber(d['capital']))
		self.output(u'最大回撤: \t%s' % formatNumber(min(d['drawdownList'])))
		
		self.output(u'平均每笔盈利：\t%s' %formatNumber(d['capital']/d['totalResult']))
		self.output(u'平均每笔滑点：\t%s' %formatNumber(d['totalSlippage']/d['totalResult']))
		self.output(u'平均每笔佣金：\t%s' %formatNumber(d['totalCommission']/d['totalResult']))
		
		self.output(u'胜率\t\t%s%%' %formatNumber(d['winningRate']))
		self.output(u'平均每笔盈利\t%s' %formatNumber(d['averageWinning']))
		self.output(u'平均每笔亏损\t%s' %formatNumber(d['averageLosing']))
		self.output(u'盈亏比：\t%s' %formatNumber(d['profitLossRatio']))
	
		# 绘图
		import matplotlib.pyplot as plt
		
		pCapital = plt.subplot(3, 1, 1)
		pCapital.set_ylabel("capital")
		pCapital.plot(d['capitalList'])
		
		pDD = plt.subplot(3, 1, 2)
		pDD.set_ylabel("DD")
		pDD.bar(range(len(d['drawdownList'])), d['drawdownList'])
		
		pPnl = plt.subplot(3, 1, 3)
		pPnl.set_ylabel("pnl")
		pPnl.hist(d['pnlList'], bins=50)
		
		plt.show()
	
	#----------------------------------------------------------------------
	def putStrategyEvent(self, name):
		"""发送策略更新事件，回测中忽略"""
		pass

	#----------------------------------------------------------------------
	def setSlippage(self, slippage):
		"""设置滑点点数"""
		self.slippage = slippage
		
	#----------------------------------------------------------------------
	def setSize(self, size):
		"""设置合约大小"""
		self.size = size
		
	#----------------------------------------------------------------------
	def setRate(self, rate):
		"""设置佣金比例"""
		self.rate = rate

	#----------------------------------------------------------------------
	def runOptimization(self, strategyClass, optimizationSetting):
		"""优化参数"""
		# 获取优化设置
		settingList = optimizationSetting.generateSetting()
		targetName = optimizationSetting.optimizeTarget
		
		# 检查参数设置问题
		if not settingList or not targetName:
			self.output(u'优化设置有问题，请检查')
		
		# 遍历优化
		resultList = []
		for setting in settingList:
			self.clearBacktestingResult()
			self.output('-' * 30)
			self.output('setting: %s' %str(setting))
			self.initStrategy(strategyClass, setting)
			self.runBacktesting()
			d = self.calculateBacktestingResult()
			try:
				targetValue = d[targetName]
			except KeyError:
				targetValue = 0
			resultList.append(([str(setting)], targetValue))
		
		# 显示结果
		resultList.sort(reverse=True, key=lambda result:result[1])
		self.output('-' * 30)
		self.output(u'优化结果：')
		for result in resultList:
			self.output(u'%s: %s' %(result[0], result[1]))
			
	#----------------------------------------------------------------------
	def clearBacktestingResult(self):
		"""清空之前回测的结果"""
		# 清空限价单相关
		self.limitOrderCount = 0
		self.limitOrderDict.clear()
		self.workingLimitOrderDict.clear()
		
		# 清空停止单相关
		self.stopOrderCount = 0
		self.stopOrderDict.clear()
		self.workingStopOrderDict.clear()
		
		# 清空成交相关
		self.tradeCount = 0
		self.tradeDict.clear()
		

########################################################################
class TradingResult(object):
	"""每笔交易的结果"""

	#----------------------------------------------------------------------
	def __init__(self, entryPrice, entryDt, exitPrice,
				 exitDt, volume, rate, slippage, size):
		"""Constructor"""
		self.entryPrice = entryPrice	# 开仓价格
		self.exitPrice = exitPrice	  # 平仓价格
		
		self.entryDt = entryDt		  # 开仓时间datetime
		self.exitDt = exitDt			# 平仓时间
		
		self.volume = volume	# 交易数量（+/-代表方向）
		
		self.turnover = (self.entryPrice+self.exitPrice)*size*abs(volume)   # 成交金额
		self.commission = self.turnover*rate								# 手续费成本
		self.slippage = slippage*2*size*abs(volume)						 # 滑点成本
		self.pnl = ((self.exitPrice - self.entryPrice) * volume * size
					- self.commission - self.slippage)					  # 净盈亏


########################################################################
class OptimizationSetting(object):
	"""优化设置"""

	#----------------------------------------------------------------------
	def __init__(self):
		"""Constructor"""
		self.paramDict = OrderedDict()
		
		self.optimizeTarget = ''		# 优化目标字段
		
	#----------------------------------------------------------------------
	def addParameter(self, name, start, end, step):
		"""增加优化参数"""
		if end <= start:
			print u'参数起始点必须小于终止点'
			return
		
		if step <= 0:
			print u'参数布进必须大于0'
			return
		
		l = []
		param = start
		
		while param <= end:
			l.append(param)
			param += step
		
		self.paramDict[name] = l
		
	#----------------------------------------------------------------------
	def generateSetting(self):
		"""生成优化参数组合"""
		# 参数名的列表
		nameList = self.paramDict.keys()
		paramList = self.paramDict.values()
		
		# 使用迭代工具生产参数对组合
		productList = list(product(*paramList))
		
		# 把参数对组合打包到一个个字典组成的列表中
		settingList = []
		for p in productList:
			d = dict(zip(nameList, p))
			settingList.append(d)
	
		return settingList
	
	#----------------------------------------------------------------------
	def setOptimizeTarget(self, target):
		"""设置优化目标字段"""
		self.optimizeTarget = target


#----------------------------------------------------------------------
def formatNumber(n):
	"""格式化数字到字符串"""
	n = round(n, 2)		 # 保留两位小数
	return format(n, ',')   # 加上千分符
	
	


if __name__ == '__main__':
	# 以下内容是一段回测脚本的演示，用户可以根据自己的需求修改
	# 建议使用ipython notebook或者spyder来做回测
	# 同样可以在命令模式下进行回测（一行一行输入运行）
	from ctaDemo import *
	
	# 创建回测引擎
	engine = BacktestingEngine()
	
	# 设置引擎的回测模式为K线
	engine.setBacktestingMode(engine.BAR_MODE)

	# 设置回测用的数据起始日期
	engine.setStartDate('20110101')
	
	# 载入历史数据到引擎中
	engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
	
	# 设置产品相关参数
	engine.setSlippage(0.2)	 # 股指1跳
	engine.setRate(0.3/10000)   # 万0.3
	engine.setSize(300)		 # 股指合约大小
	
	# 在引擎中创建策略对象
	engine.initStrategy(DoubleEmaDemo, {})
	
	# 开始跑回测
	engine.runBacktesting()
	
	# 显示回测结果
	# spyder或者ipython notebook中运行时，会弹出盈亏曲线图
	# 直接在cmd中回测则只会打印一些回测数值
	engine.showBacktestingResult()
	
	